# https://gitee.com/yueyinqiu5990/tj12413601/blob/master/assignment4/question2/datasets.py
import csv

import torch.utils.data

import torch_plus


class EpsDataset(torch.utils.data.Dataset):
    """In-memory map-style dataset of ``(omega, stack([r, i]))`` samples from a CSV.

    Each data row must hold at least three numeric columns:

    - Column 0 is converted to a tensor returned as the first element (``omega``).
    - Columns 1 and 2 are converted to tensors and stacked into the second element.

    Presumably ``r``/``i`` are the real and imaginary parts of an eps value.
    TODO confirm against the data source.

    The first row of the file is treated as a header and skipped. Blank rows are
    ignored. All samples are read eagerly in the constructor.
    """

    def __init__(self, path: str = "./data/problem2.csv"):
        """Read every sample from ``path`` into memory.

        Args:
            path: CSV file to load. Defaults to the historical hard-coded
                location so existing ``EpsDataset()`` calls keep working.

        Raises:
            OSError: if the file cannot be opened.
            ValueError: if a cell in columns 0-2 is not a valid float.
            IndexError: if a non-blank data row has fewer than three columns.
        """
        self._data: list[tuple[torch.Tensor, torch.Tensor]] = []
        with open(path, encoding="utf8", newline="") as file:
            reader = csv.reader(file)
            # Skip the header row. A completely empty file yields an empty
            # dataset instead of leaking StopIteration out of the constructor.
            next(reader, None)
            for row in reader:
                if not row:
                    # csv.reader yields [] for blank lines; they carry no sample.
                    continue
                omega = torch_plus.as_tensor(float(row[0]))
                r = torch_plus.as_tensor(float(row[1]))
                i = torch_plus.as_tensor(float(row[2]))
                self._data.append((
                    omega,
                    torch.stack([r, i])
                ))

    def __getitem__(self, item):
        """Return the stored ``(omega, stack([r, i]))`` tuple at index ``item``."""
        return self._data[item]

    def __len__(self) -> int:
        """Return the number of samples loaded from the CSV."""
        return len(self._data)
